import torch
from torch import nn

from .model import JointTransformerBlock

class ZImageControlTransformerBlock(JointTransformerBlock):
    """JointTransformerBlock variant used for the control layers.

    On top of the parent block it owns an ``after_proj`` linear layer whose
    output is returned next to the block output and, only when
    ``block_id == 0``, a ``before_proj`` linear layer that is applied to the
    incoming control state before it is summed with ``x``.
    """

    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
        block_id=0,
        operation_settings=None,
    ):
        """Build the parent block, then attach the control projections.

        Args:
            layer_id, dim, n_heads, n_kv_heads, multiple_of,
            ffn_dim_multiplier, norm_eps, qk_norm, modulation: forwarded
                unchanged to ``JointTransformerBlock.__init__``.
            block_id: index of this block among the control layers; block 0
                additionally gets ``before_proj``.
            operation_settings: mapping read via ``.get`` for the keys
                ``"operations"``, ``"device"`` and ``"dtype"``. The ``None``
                default cannot actually be used, since ``.get`` is called
                on it below.
        """
        super().__init__(
            layer_id,
            dim,
            n_heads,
            n_kv_heads,
            multiple_of,
            ffn_dim_multiplier,
            norm_eps,
            qk_norm,
            modulation,
            z_image_modulation=True,
            operation_settings=operation_settings,
        )
        self.block_id = block_id

        ops = operation_settings.get("operations")
        placement = {
            "device": operation_settings.get("device"),
            "dtype": operation_settings.get("dtype"),
        }
        # self.dim is not assigned in this class; presumably the parent
        # constructor sets it -- confirm against JointTransformerBlock.
        if block_id == 0:
            self.before_proj = ops.Linear(self.dim, self.dim, **placement)
        self.after_proj = ops.Linear(self.dim, self.dim, **placement)

    def forward(self, c, x, **kwargs):
        """Run the block on control state ``c``.

        Args:
            c: control hidden state.
            x: tensor added to the projected ``c``; only read by block 0.
            **kwargs: forwarded as-is to ``JointTransformerBlock.forward``.

        Returns:
            ``(c_skip, c)`` where ``c`` is the parent block's output and
            ``c_skip`` is ``after_proj(c)``.
        """
        hidden = c
        if self.block_id == 0:
            hidden = self.before_proj(hidden) + x
        hidden = super().forward(hidden, **kwargs)
        skip = self.after_proj(hidden)
        return skip, hidden

class ZImage_Control(torch.nn.Module):
    """Container for the control-branch submodules.

    Holds, in construction order:
      * ``control_layers``: 6 ``ZImageControlTransformerBlock`` instances,
      * ``control_all_x_embedder``: a ``ModuleDict`` with a single linear
        patch embedder stored under the key ``"2-1"``,
      * ``control_noise_refiner``: 2 ``JointTransformerBlock`` instances.
    """

    def __init__(
        self,
        dim: int = 3840,
        n_heads: int = 30,
        n_kv_heads: int = 30,
        multiple_of: int = 256,
        ffn_dim_multiplier: float = (8.0 / 3.0),
        norm_eps: float = 1e-5,
        qk_norm: bool = True,
        dtype=None,
        device=None,
        operations=None,
        **kwargs
    ):
        """Create all submodules.

        Args:
            dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier,
            norm_eps, qk_norm: hyper-parameters handed to every transformer
                block built here.
            dtype, device: passed to the created layers.
            operations: namespace providing ``Linear``; it is dereferenced
                directly, so the ``None`` default cannot actually be used.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        operation_settings = {"operations": operations, "device": device, "dtype": dtype}

        self.additional_in_dim = 0
        self.control_in_dim = 16
        self.n_control_layers = 6
        n_refiner_layers = 2

        # Positional hyper-parameters common to both kinds of blocks.
        block_args = (dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm)

        control_blocks = [
            ZImageControlTransformerBlock(
                idx,
                *block_args,
                block_id=idx,
                operation_settings=operation_settings,
            )
            for idx in range(self.n_control_layers)
        ]
        self.control_layers = nn.ModuleList(control_blocks)

        patch_size = 2
        f_patch_size = 1
        # One flattened patch holds f_patch * patch * patch * channels values.
        patch_features = f_patch_size * patch_size * patch_size * self.control_in_dim
        embedders = {
            f"{patch_size}-{f_patch_size}": operations.Linear(patch_features, dim, bias=True, device=device, dtype=dtype),
        }
        self.control_all_x_embedder = nn.ModuleDict(embedders)

        refiner_blocks = [
            JointTransformerBlock(
                layer_id,
                *block_args,
                modulation=True,
                z_image_modulation=True,
                operation_settings=operation_settings,
            )
            for layer_id in range(n_refiner_layers)
        ]
        self.control_noise_refiner = nn.ModuleList(refiner_blocks)

    def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
        """Patchify and embed ``control_context``, then run the refiner blocks.

        Args:
            cap_feats: not used by this method.
            control_context: 4-D tensor unpacked as ``(B, C, H, W)``. ``H``
                and ``W`` must be divisible by 2 for the ``view`` to succeed.
            x_freqs_cis: sliced to the first ``B`` entries of dim 0 and the
                first ``num_tokens`` entries of dim 1 before each refiner call.
            adaln_input: handed unchanged to every refiner block.

        Returns:
            The output of the last refiner block (or the embedded tokens if
            there are no refiner blocks).
        """
        patch_size = 2
        f_patch_size = 1
        pH = pW = patch_size
        batch, channels, height, width = control_context.shape

        embed = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"]
        # (B, C, H, W) -> (B, C, H/pH, pH, W/pW, pW)
        patches = control_context.view(batch, channels, height // pH, pH, width // pW, pW)
        # -> (B, H/pH, W/pW, pH, pW, C) -> (B, H/pH * W/pW, pH * pW * C)
        tokens = patches.permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
        hidden = embed(tokens)

        for refiner in self.control_noise_refiner:
            n_batch, n_tokens = hidden.shape[0], hidden.shape[1]
            # The refiner blocks are always called with a None attention mask.
            hidden = refiner(hidden, None, x_freqs_cis[:n_batch, :n_tokens], adaln_input)
        return hidden

    def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
        """Run a single control layer.

        Args:
            layer_id: index into ``self.control_layers``.
            control_context: control hidden state passed as the block's ``c``.
            x: tensor passed as the block's ``x``.
            x_attn_mask: forwarded as ``x_mask``.
            x_freqs_cis: sliced to ``control_context``'s first two dimensions
                and forwarded as ``freqs_cis``.
            adaln_input: forwarded as ``adaln_input``.

        Returns:
            Whatever the selected control block returns.
        """
        block = self.control_layers[layer_id]
        n_batch, n_tokens = control_context.shape[0], control_context.shape[1]
        freqs = x_freqs_cis[:n_batch, :n_tokens]
        return block(control_context, x, x_mask=x_attn_mask, freqs_cis=freqs, adaln_input=adaln_input)
